{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "批量归一化\n",
    "batch normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import autograd, gluon, init, nd\n",
    "from mxnet.gluon import nn\n",
    "\n",
    "def batch_norm(X, gamma, beta, moving_mean, moving_var, eps=1e-5,\n",
    "               momentum=0.9):\n",
    "    \"\"\"Apply batch normalization to X and advance the running statistics.\n",
    "\n",
    "    When autograd.is_training() is False, X is normalized with the supplied\n",
    "    moving_mean / moving_var. Otherwise X is normalized with the statistics\n",
    "    of the current mini-batch and the moving statistics take one\n",
    "    exponential-moving-average step towards them.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D input (statistics over axis 0) or 4-D input (statistics over\n",
    "            axes 0, 2 and 3, i.e. one value per index of axis 1).\n",
    "        gamma: scale multiplied onto the normalized input.\n",
    "        beta: shift added after scaling.\n",
    "        moving_mean: running mean, used for normalization outside training.\n",
    "        moving_var: running variance, used for normalization outside training.\n",
    "        eps: small constant added to the variance before the square root.\n",
    "        momentum: weight kept on the previous moving statistics.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (Y, moving_mean, moving_var). Outside training mode the moving\n",
    "        statistics are returned unchanged.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: in training mode, if X is neither 2-D nor 4-D.\n",
    "    \"\"\"\n",
    "    if not autograd.is_training():\n",
    "        # Prediction: normalize with the accumulated running statistics.\n",
    "        X_hat = (X - moving_mean) / nd.sqrt(moving_var + eps)\n",
    "    else:\n",
    "        assert len(X.shape) in (2, 4), 'batch_norm expects a 2-D or 4-D input'\n",
    "        if len(X.shape) == 2:\n",
    "            # 2-D input: one mean / variance per column.\n",
    "            mean = X.mean(axis=0)\n",
    "            var = ((X - mean) ** 2).mean(axis=0)\n",
    "        else:\n",
    "            # 4-D input: reduce over axes 0, 2, 3; keepdims so the result\n",
    "            # broadcasts against X.\n",
    "            mean = X.mean(axis=(0, 2, 3), keepdims=True)\n",
    "            var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)\n",
    "        # Training: normalize with the current mini-batch statistics.\n",
    "        X_hat = (X - mean) / nd.sqrt(var + eps)\n",
    "        # The running statistics are bookkeeping only; no gradient should\n",
    "        # flow through them, so keep their update out of the recorded graph.\n",
    "        with autograd.pause():\n",
    "            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "            moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
    "    Y = gamma * X_hat + beta  # scale and shift\n",
    "    return Y, moving_mean, moving_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNorm(nn.Block):\n",
    "    \"\"\"Batch normalization layer implemented with batch_norm.\n",
    "\n",
    "    The scale (gamma) and shift (beta) are Gluon parameters. The running\n",
    "    mean and variance are kept as plain NDArray attributes and are replaced\n",
    "    by the values batch_norm returns on every forward pass.\n",
    "\n",
    "    Args:\n",
    "        num_features: length of axis 1 of the parameter / statistic arrays.\n",
    "        num_dims: 2 gives arrays of shape (1, num_features); any other value\n",
    "            gives (1, num_features, 1, 1).\n",
    "        **kwargs: forwarded to nn.Block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_dims, **kwargs):\n",
    "        super(BatchNorm, self).__init__(**kwargs)\n",
    "        if num_dims == 2:\n",
    "            shape = (1, num_features)\n",
    "        else:\n",
    "            shape = (1, num_features, 1, 1)\n",
    "        # Learnable scale and shift, initialized to 1 and 0 respectively.\n",
    "        self.gamma = self.params.get('gamma', shape=shape, init=init.One())\n",
    "        self.beta = self.params.get('beta', shape=shape, init=init.Zero())\n",
    "        # Running statistics are plain NDArrays, not parameters.\n",
    "        self.moving_mean = nd.zeros(shape)\n",
    "        # Start the running variance at 1 (not 0): with a zero variance,\n",
    "        # prediction before the statistics have warmed up would divide by\n",
    "        # sqrt(eps) and blow the activations up.\n",
    "        self.moving_var = nd.ones(shape)\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Normalize X and store the refreshed running statistics.\"\"\"\n",
    "        # If the running statistics live on a different context than X,\n",
    "        # copy them over to X's context first.\n",
    "        if self.moving_mean.context != X.context:\n",
    "            self.moving_mean = self.moving_mean.copyto(X.context)\n",
    "            self.moving_var = self.moving_var.copyto(X.context)\n",
    "        # Keep the updated moving_mean and moving_var for the next call.\n",
    "        Y, self.moving_mean, self.moving_var = batch_norm(\n",
    "            X, self.gamma.data(), self.beta.data(), self.moving_mean,\n",
    "            self.moving_var, eps=1e-5, momentum=0.9)\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "使用批量归一化层的LeNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LeNet with a BatchNorm layer inserted before every sigmoid activation.\n",
    "net = nn.Sequential()\n",
    "# Convolution stages: conv -> batch norm -> sigmoid -> 2x2 max pooling.\n",
    "for num_channels in (6, 16):\n",
    "    net.add(nn.Conv2D(num_channels, kernel_size=5),\n",
    "            BatchNorm(num_channels, num_dims=4),\n",
    "            nn.Activation('sigmoid'),\n",
    "            nn.MaxPool2D(pool_size=2, strides=2))\n",
    "# Hidden dense stages: dense -> batch norm -> sigmoid.\n",
    "for num_hidden in (120, 84):\n",
    "    net.add(nn.Dense(num_hidden),\n",
    "            BatchNorm(num_hidden, num_dims=2),\n",
    "            nn.Activation('sigmoid'))\n",
    "# Output layer with 10 units.\n",
    "net.add(nn.Dense(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on cpu(0)\n",
      "epoch 1, loss 0.6602, train acc 0.766, test acc 0.836, time 148.9 sec\n",
      "epoch 2, loss 0.3975, train acc 0.856, test acc 0.851, time 150.4 sec\n",
      "epoch 3, loss 0.3458, train acc 0.874, test acc 0.881, time 150.6 sec\n",
      "epoch 4, loss 0.3202, train acc 0.885, test acc 0.853, time 148.6 sec\n",
      "epoch 5, loss 0.3000, train acc 0.892, test acc 0.840, time 146.7 sec\n"
     ]
    }
   ],
   "source": [
    "# Train the modified LeNet model.\n",
    "# Training settings, one per line so each is easy to tweak.\n",
    "lr = 1.0\n",
    "num_epochs = 5\n",
    "batch_size = 256\n",
    "ctx = d2l.try_gpu()  # the saved run below reports cpu(0)\n",
    "\n",
    "net.initialize(ctx=ctx, init=init.Xavier())\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', dict(learning_rate=lr))\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(\n",
       " [1.9158536 1.6384647 1.6526939 1.0363525 1.5951579 1.8983926]\n",
       " <NDArray 6 @cpu(0)>, \n",
       " [ 1.1872159  -0.12765469  0.28762192  0.18677962 -0.3699979  -1.9960198 ]\n",
       " <NDArray 6 @cpu(0)>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learned scale (gamma) and shift (beta) of the first batch-norm layer,\n",
    "# each flattened to a 1-D array for display.\n",
    "tuple(param.data().reshape((-1,)) for param in (net[1].gamma, net[1].beta))"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same LeNet layout, this time using Gluon's built-in nn.BatchNorm.\n",
    "net = nn.Sequential()\n",
    "# Convolution stages: conv -> batch norm -> sigmoid -> 2x2 max pooling.\n",
    "for num_channels in (6, 16):\n",
    "    net.add(nn.Conv2D(num_channels, kernel_size=5),\n",
    "            nn.BatchNorm(),\n",
    "            nn.Activation('sigmoid'),\n",
    "            nn.MaxPool2D(pool_size=2, strides=2))\n",
    "# Hidden dense stages: dense -> batch norm -> sigmoid.\n",
    "for num_hidden in (120, 84):\n",
    "    net.add(nn.Dense(num_hidden),\n",
    "            nn.BatchNorm(),\n",
    "            nn.Activation('sigmoid'))\n",
    "# Output layer with 10 units.\n",
    "net.add(nn.Dense(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on cpu(0)\n",
      "epoch 1, loss 0.6440, train acc 0.771, test acc 0.833, time 21.9 sec\n",
      "epoch 2, loss 0.3907, train acc 0.859, test acc 0.868, time 20.9 sec\n",
      "epoch 3, loss 0.3461, train acc 0.875, test acc 0.869, time 20.8 sec\n",
      "epoch 4, loss 0.3189, train acc 0.883, test acc 0.867, time 21.0 sec\n",
      "epoch 5, loss 0.2980, train acc 0.892, test acc 0.883, time 20.9 sec\n"
     ]
    }
   ],
   "source": [
    "# Train the nn.BatchNorm network; lr, ctx, batch_size, num_epochs and the\n",
    "# data iterators are the ones defined in the earlier training cell.\n",
    "net.initialize(init=init.Xavier(), ctx=ctx)\n",
    "trainer = gluon.Trainer(params=net.collect_params(),\n",
    "                        optimizer='sgd',\n",
    "                        optimizer_params={'learning_rate': lr})\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
